package com.sql.mysql.sharding.plugin;

import java.util.Properties;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import com.sql.mysql.sharding.annotation.ShardingKeyObject;
import com.sql.mysql.sharding.datasource.ShardingDataSource;
import com.sql.mysql.sharding.rewrite.ReWriteEngine;
import com.sql.mysql.sharding.threadlocal.ShardingThreadLocal;
import com.sql.mysql.sharding.utils.NumberUtils;

import lombok.extern.slf4j.Slf4j;

@Intercepts(value = {
		// @Signature(type = StatementHandler.class, method = "batch", args = {
		// Statement.class }),
		// @Signature(type = StatementHandler.class, method = "update", args = {
		// Statement.class }),
		// @Signature(type = StatementHandler.class, method = "query", args = {
		// Statement.class, ResultHandler.class }),
		// @Signature(type = StatementHandler.class, method = "queryCursor",
		// args = { Statement.class })

})
@Slf4j
public class ShardStatementHandlerInterceptor implements Interceptor {

	@SuppressWarnings("unused")
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		long begin = System.currentTimeMillis();
		Object result = invocation.proceed();
		long end = System.currentTimeMillis();
		// ThreadLocalHelper.SqlExecuteTimeThreadLocal.set(end - begin);
		return result;
	}

	private void checkTransactionConsistency(long currentShardId) throws Exception {
		if (true == ShardingThreadLocal.AutoCommitThreadLocal.get()) {
			log.info("本次是自动提交,不需要校验");
			return;
		}
		log.info("本次是手动提交");
		if (null == ShardingThreadLocal.LastShardingKeyIDThreadLocal.get()) {
			log.info("上一次没有值,不需要进行一致性校验,但是需要保存当前分片值,供下一次比对使用");
			ShardingThreadLocal.LastShardingKeyIDThreadLocal.set((Long) currentShardId);
			return;
		}
		log.info("上一次有值,进行比对");
		long minus = currentShardId - ShardingThreadLocal.LastShardingKeyIDThreadLocal.get();
		if (0 != minus) {
			log.error("not the same {},{}", currentShardId, ShardingThreadLocal.LastShardingKeyIDThreadLocal.get());
			throw new Exception("currentShardId: " + currentShardId + " is not the same as last: "
					+ ShardingThreadLocal.LastShardingKeyIDThreadLocal.get());
		} else {
			log.info("succeed, shardId are all consistent !!! bingo");
		}
	}

	@Override
	public Object plugin(Object target) {
		// 仅仅处理RoutingStatementHandler
		// 仅仅为了在获取Connection之前获得sql模板
		// 执行到这里的时候，sql已经之前生成了,所以，尽管使用
		// log.info("StatementHandlerInterceptor.plugin--->" + target);
		if (target instanceof StatementHandler) {
			// 这里进行sql的改写,配合后面的getConnection操作
			StatementHandler statementHandler = (StatementHandler) target;
			BoundSql boundSql = statementHandler.getBoundSql();
			// 保存之前先改写
			// 如果改写成功就绑定这个boundSql
			ShardingKeyObject shardingKeyObject = null;
			MetaObject metaObject = SystemMetaObject.forObject(boundSql);
			try {
				shardingKeyObject = ShardingThreadLocal.ShardingKeyObjectThreadLocal.get();
				ReWriteEngine.rewrite(metaObject, shardingKeyObject, ShardingDataSource.savedShardGroups, boundSql);
				ShardingThreadLocal.BoundSqlThreadLocal.set(boundSql);
				// 非常重要
				// 进行事务的一致性比对
				checkTransactionConsistency(NumberUtils.Object2Long(shardingKeyObject.getValue()));
			} catch (Exception e) {
				// 如果改写失败就不绑定,sql也设置为null,mybatis执行报错,后面数据源获取连接就会报错
				log.error("error happened {}", e);
				metaObject.setValue("sql", null);
				ShardingThreadLocal.BoundSqlThreadLocal.set(null);
			}
		}
		// 结束
		return target;
	}

	@Override
	public void setProperties(Properties properties) {

	}

}
